# -*- coding: utf-8 -*-
"""STRAINER_demo.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1z6hrLazoJQd5zMMhaxKy_wnIkLKOG_9b

## Demo code for STRAINER.
"""

!pip install torch torchvision numpy scipy scikit-image opencv-python matplotlib tqdm ipdb

import gdown
gdown.download_folder("https://drive.google.com/drive/folders/1n4YN3wClDM7kKXp4OPrcuqytwHDC_wFb?usp=sharing")

import torch
import torch.nn as nn
torch.manual_seed(1234)
import os, os.path as osp
import numpy as np
import skimage, skimage.transform, skimage.io, skimage.filters
import matplotlib as mpl
from matplotlib import pyplot as plt
from tqdm.autonotebook import tqdm
from collections import defaultdict, OrderedDict
import math
import cv2
from copy import deepcopy

import glob
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class SineLayer(nn.Module):
    '''
        Fully-connected layer followed by a sine: sin(omega_0 * (W x + b)).

        See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for
        discussion of omega_0.

        If is_first=True, the weights are drawn from U(-1/in, 1/in) and omega_0
        acts as a frequency factor on the first layer's pre-activations.
        Different signals may require different omega_0 in the first layer -
        this is a hyperparameter.

        If is_first=False, the weights are drawn from
        U(-sqrt(6/in)/omega_0, sqrt(6/in)/omega_0), i.e. divided by omega_0 so
        as to keep the magnitude of activations constant while boosting
        gradients to the weight matrix (see supplement Sec. 1.5).

        `scale` is accepted for interface compatibility and is not used here.
        With init_weights=False the nn.Linear default initialisation is kept.
    '''

    def __init__(self, in_features, out_features, bias=True,
                 is_first=False, omega_0=30, scale=10.0, init_weights=True):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)

        if not init_weights:
            return
        self.init_weights()

    def init_weights(self):
        # Only the weight matrix is re-initialised; the bias keeps nn.Linear's default.
        fan_in = self.in_features
        if self.is_first:
            bound = 1 / fan_in
        else:
            bound = np.sqrt(6 / fan_in) / self.omega_0
        with torch.no_grad():
            self.linear.weight.uniform_(-bound, bound)

    def forward(self, input):
        pre_activation = self.linear(input)
        return torch.sin(self.omega_0 * pre_activation)

class INR(nn.Module):
    """Sine-activated MLP mapping input coordinates to output values.

    Layers are appended to ``self.net`` in this order:
      1. if ``hidden_layers != 0``: a first ``SineLayer`` in_features -> hidden_features
         (``is_first=True``, ``omega_0=first_omega_0``);
      2. ``hidden_layers`` further ``SineLayer``s hidden -> hidden
         (``omega_0=hidden_omega_0``), minus one when ``outermost_linear`` is True
         and ``hidden_layers > 0``;
      3. a final ``nn.Linear`` -> out_features when ``outermost_linear`` is True or
         the (adjusted) ``hidden_layers`` count is 0.

    For ``hidden_layers >= 0`` the network therefore has ``hidden_layers + 1``
    layers either way; a linear head takes the place of the last sine layer.

    ``scale`` is forwarded to every SineLayer; ``sidelength``, ``fn_samples`` and
    ``use_nyquist`` are accepted for interface compatibility and unused here.
    ``no_init=True`` keeps PyTorch's default initialisation for all layers.
    """

    def __init__(self, in_features, hidden_features,
                 hidden_layers,
                 out_features, outermost_linear=True,
                 first_omega_0=30, hidden_omega_0=30., scale=10.0,
                 pos_encode=False, sidelength=512, fn_samples=None,
                 use_nyquist=True, no_init=False):
        super().__init__()
        self.pos_encode = pos_encode
        self.nonlin = SineLayer

        layers = []
        # Width of the tensor entering the next layer: the raw input width until
        # a sine layer has been added, hidden_features afterwards.
        feature_dim = in_features
        if hidden_layers != 0:
            # First sine layer.
            layers.append(self.nonlin(in_features, hidden_features,
                                      is_first=True, omega_0=first_omega_0,
                                      scale=scale, init_weights=(not no_init)))
            feature_dim = hidden_features
        if hidden_layers > 0 and outermost_linear is True:
            hidden_layers -= 1
        for _ in range(hidden_layers):
            layers.append(self.nonlin(hidden_features, hidden_features,
                                      is_first=False, omega_0=hidden_omega_0,
                                      scale=scale, init_weights=(not no_init)))

        if outermost_linear or (hidden_layers == 0):
            # Use the actual incoming width: with no sine layers this is
            # in_features, not hidden_features.
            final_linear = nn.Linear(feature_dim, out_features, dtype=torch.float)

            if not no_init:
                with torch.no_grad():
                    # Bound sqrt(6 / fan_in) / omega_0; max() avoids dividing by
                    # zero when hidden_omega_0 is 0.
                    const = np.sqrt(6 / feature_dim) / max(hidden_omega_0, 1e-12)
                    final_linear.weight.uniform_(-const, const)

            layers.append(final_linear)

        self.net = nn.Sequential(*layers)

    def forward(self, coords):
        """Run ``coords`` through the network (after positional encoding if enabled)."""
        if self.pos_encode:
            # INR itself defines no positional_encoding; a subclass must supply it.
            encode = getattr(self, "positional_encoding", None)
            if encode is None:
                raise NotImplementedError(
                    "pos_encode=True requires a `positional_encoding` method, "
                    "which INR does not define")
            coords = encode(coords)

        return self.net(coords)


class SharedINR(nn.Module):
    """One shared sine encoder feeding ``num_decoders`` independent INR heads.

    The encoder is an ``INR`` built with ``hidden_layers=shared_encoder_layers - 1``
    and no linear head (its output width is ``hidden_features``). Each decoder is
    an ``INR`` taking ``hidden_features`` inputs with
    ``hidden_layers=(hidden_layers - shared_encoder_layers) - 1``.
    ``forward`` returns a list with one output per decoder, in decoder order.
    """

    def __init__(self, in_features, hidden_features,
                 hidden_layers,
                 out_features, outermost_linear=True,
                 first_omega_0=30, hidden_omega_0=30., scale=10.0,
                 pos_encode=False, sidelength=512, fn_samples=None,
                 use_nyquist=True, shared_encoder_layers=None, num_decoders=None, no_init=False):

        super().__init__()
        assert shared_encoder_layers is not None, "Please mention shared_encoder_layers. Use 0 if none are shared"
        assert hidden_layers > shared_encoder_layers, "Total hidden layers must be greater than number of layers in shared encoder"
        # Fail early with a clear message instead of a TypeError from range(None) below.
        assert num_decoders is not None, "Please mention num_decoders (number of decoder heads)"
        self.shared_encoder_layers = shared_encoder_layers
        self.num_decoders = num_decoders

        self.encoderINR = INR(
            in_features=in_features,
            hidden_features=hidden_features,
            hidden_layers=self.shared_encoder_layers - 1, # input is a layer
            out_features=hidden_features,
            outermost_linear=False,
            first_omega_0=first_omega_0,
            hidden_omega_0=hidden_omega_0,
            scale=scale,
            pos_encode=pos_encode,
            sidelength=sidelength,
            fn_samples=fn_samples,
            use_nyquist=use_nyquist,
            no_init=no_init
        )

        self.num_decoder_layers = hidden_layers - self.shared_encoder_layers
        assert self.num_decoder_layers >= 1, "Num decoder layers must be at least 1"
        self.decoderINRs = nn.ModuleList([
            INR(
                in_features=hidden_features,
                hidden_features=hidden_features,
                hidden_layers=self.num_decoder_layers - 1,
                out_features=out_features,
                outermost_linear=outermost_linear,
                first_omega_0=first_omega_0,
                hidden_omega_0=hidden_omega_0,
                scale=scale,
                pos_encode=pos_encode,
                sidelength=sidelength,
                fn_samples=fn_samples,
                use_nyquist=use_nyquist,
                no_init=no_init
            ) for _ in range(self.num_decoders)])

    def forward(self, coords):
        """Encode ``coords`` once and return a list of every decoder's output."""
        encoded_features = self.encoderINR(coords)
        return [decoder(encoded_features) for decoder in self.decoderINRs]

    def load_encoder_weights_from(self, fellow_model):
        """Copy the shared-encoder weights of another SharedINR into this one."""
        self.encoderINR.load_state_dict(deepcopy(fellow_model.encoderINR.state_dict()))

    def load_weights_from_file(self, file, key="encoderINR"):
        """Load encoder weights from a checkpoint dict's 'encoder_weights' entry.

        NOTE(review): ``key`` is currently ignored; the checkpoint entry name is
        hard-coded to 'encoder_weights'.
        """
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict then copies the values onto this module's parameters.
        weights = torch.load(file, map_location="cpu")
        self.encoderINR.load_state_dict(deepcopy(weights['encoder_weights']))

# Working resolution for every image; also the size of the coordinate grid.
IMG_SIZE = (178, 178)
POS_ENCODE = False

# Experiment configuration (plain defaultdict with no default factory).
config = defaultdict(None, {
    # Optimisation schedule.
    'epochs': 2100,                    # iterations per single-image fit
    'epochs_train_strainer': 5001,     # iterations for shared-encoder training
    'learning_rate': 1e-4,
    'plot_every': 500,
    'image_size': IMG_SIZE,
    # INR architecture.
    'num_layers': 6,
    'hidden_features': 256,
    'in_channels': 2,                  # (x, y) coordinates
    'out_channels': 3,
    'shared_encoder_layers': 5,
    'num_decoders': 1,
    'nonlin': 'siren',
})

TRAINING_PATH = "./STRAINER_DATA/train"
TESTING_PATH = "./STRAINER_DATA/test"

def get_celeba_sample_v1(path, take=1):
    """Load, min-max normalise and resize the first ``take`` ``.npy`` arrays in ``path``.

    Files are sorted by name so the selection is deterministic. ``take=-1``
    loads every file; any other value is used as a slice end (``files[:take]``).
    Each array is scaled to [0, 1] (a constant array becomes all zeros instead of
    NaN) and resized to ``IMG_SIZE`` with area interpolation. The raw array shape
    of every file is printed.

    Returns:
        list of numpy arrays, one per loaded file.
    Raises:
        AssertionError: if ``path`` contains no files.
    """
    files = sorted(glob.glob(osp.join(path, "*"))) # deterministic
    assert len(files) != 0, "Make sure path is correct. No files in path"
    subfiles = files if take == -1 else files[:take]

    images = []
    for filename in subfiles:
        data = np.load(filename)
        lo, hi = data.min(), data.max()
        span = hi - lo
        if span == 0:
            # Constant image: (data - lo) / 0 would be all NaN.
            data = np.zeros_like(data, dtype=np.result_type(data, 1.0))
        else:
            data = (data - lo) / span
        print(data.shape)
        # NOTE: cv2.resize takes dsize as (width, height); IMG_SIZE is square here.
        images.append(cv2.resize(data, IMG_SIZE, interpolation=cv2.INTER_AREA))

    return images

# Private sentinel: "no device given, pick CUDA if available else CPU".
_AUTO_DEVICE = object()


def get_coords(H, W, device=_AUTO_DEVICE, T=None):
    """Return a regular grid of coordinates in [-1, 1].

    Args:
        H, W: grid height and width.
        device: target device. When omitted, CUDA is used if available and CPU
            otherwise; the old hard-coded CUDA default failed on CPU-only hosts.
        T: optional third (depth/time) axis length.

    Returns:
        If ``T`` is None: a tensor of shape (1, H*W, 2) holding (x, y) pairs,
        with x varying fastest. Otherwise: a float32 tensor of shape (H*W*T, 3)
        holding (x, y, z) triples, with no leading batch dimension.
    """
    if device is _AUTO_DEVICE:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    if T is None:
        x = torch.linspace(-1, 1, W).to(device)
        y = torch.linspace(-1, 1, H).to(device)
        X, Y = torch.meshgrid(x, y, indexing='xy')
        return torch.hstack((X.reshape(-1, 1), Y.reshape(-1, 1)))[None, ...]

    X, Y, Z = np.meshgrid(np.linspace(-1, 1, W),
                          np.linspace(-1, 1, H),
                          np.linspace(-1, 1, T))
    coords = np.hstack((X.reshape(-1, 1),
                        Y.reshape(-1, 1),
                        Z.reshape(-1, 1)))
    return torch.tensor(coords.astype(np.float32)).to(device)

# One (1, H*W, 2) grid of (x, y) coordinates in [-1, 1], shared by every model
# below; the print is a quick sanity check of that range.
coords = get_coords(IMG_SIZE[0], IMG_SIZE[1], device=device)
print(coords.min(), coords.max())

def convert_to_tensor(images, output_format="nhwc", target_device=None):
    """Stack numpy images (each H x W x C) into one float32 tensor.

    Args:
        images: sequence of numpy arrays with identical shapes.
        output_format: "nhwc" (default; images stacked as-is) or "nchw"
            (channels moved first). Matching is case-insensitive.
        target_device: device for the result; defaults to the module-level
            ``device``.

    Returns:
        float32 tensor of shape (N, H, W, C) or (N, C, H, W).

    Raises:
        ValueError: for an unrecognised ``output_format``. Previously, anything
            other than "nchw" was silently treated as "nhwc".
    """
    fmt = output_format.lower()
    if fmt not in ("nhwc", "nchw"):
        raise ValueError(f"output_format must be 'nhwc' or 'nchw', got {output_format!r}")

    target = device if target_device is None else target_device
    if fmt == "nchw":
        im_tensors = [torch.from_numpy(im.transpose(2, 0, 1)) for im in images]
    else:
        im_tensors = [torch.from_numpy(im) for im in images]

    return torch.stack(im_tensors).float().to(target)

# Load the first ten training images and the first test image (each loader call
# prints the raw shape of every file it reads), then stack them as NHWC tensors.
train_images = get_celeba_sample_v1(TRAINING_PATH, take=10)
im_tensor_train = convert_to_tensor(train_images, "nhwc")
images = get_celeba_sample_v1(TESTING_PATH, take=1)
im_tensors = convert_to_tensor(images, "nhwc")

# Preview grid of 2 rows x 5 columns, filled column by column: train_images[0]
# top-left, train_images[1] directly below it, and so on. Needs >= 10 images.
print("train data")
fig, axs = plt.subplots(2, 5)
for col in range(5):
    for row in range(2):
        axs[row, col].imshow(train_images[2 * col + row])
plt.show()

print("Test images")
plt.figure()
plt.imshow(images[0])
plt.show()

im_tensor_train.shape  # notebook cell output only; a no-op when run as a script

# Targets flattened to (1, -1, 3) so they line up with the (1, H*W, 2) coordinate
# grid (assumes 3-channel images, matching config['out_channels']).
data_dict_1 = dict(image_size=IMG_SIZE, gt=im_tensor_train[1].reshape(1, -1, 3))
data_dict_train_strainer = dict(image_size=IMG_SIZE,
                                gt=[img.reshape(1, -1, 3) for img in im_tensor_train])
data_dict_test = dict(image_size=IMG_SIZE, gt=im_tensors[0].reshape(1, -1, 3))

def fit_inr(coords, data, model, optim, config=None, mlogger=None, name=None):
    """Fit ``model`` to one target signal by full-batch MSE regression.

    Args:
        coords: model input (e.g. the coordinate grid).
        data: dict whose 'gt' entry is the target tensor compared to the output.
        model: a module returning a 3-D tensor, or a list of them (only the first
            entry is trained, as for a single-decoder SharedINR).
        optim: optimizer over the model's parameters.
        config: dict providing 'epochs' (number of iterations).
        mlogger: unused; kept for interface compatibility.
        name: run name; required.

    Returns:
        {"psnr": [float, ...]}: one PSNR value per iteration, computed from that
        iteration's loss (before the optimizer step) as -10*log10(MSE). This
        assumes targets are scaled to [0, 1].

    Raises:
        AssertionError: if ``name`` is None.
        ValueError: if the model output is not 3-dimensional.
    """
    assert name is not None, "`name` must be provided as metric logger needs it"
    config = {} if config is None else config
    gt_tensor = data['gt']
    num_epochs = config['epochs']

    tbar = tqdm(range(num_epochs))
    psnr_vals = []
    for epoch in tbar:
        outputs = model(coords)
        output = outputs[0] if isinstance(outputs, list) else outputs
        if output.dim() != 3:
            raise ValueError(f"expected a 3-D model output, got shape {tuple(output.shape)}")
        loss = ((output - gt_tensor)**2).mean()
        optim.zero_grad()
        loss.backward()
        optim.step()

        # Single host sync; PSNR is computed outside the autograd graph.
        loss_val = loss.item()
        psnr = -10 * math.log10(loss_val) if loss_val > 0 else math.inf
        psnr_vals.append(psnr)

        # set_description refreshes the bar itself.
        tbar.set_description(f"Iter {epoch}/{config['epochs']} Loss = {loss_val:6f} PSNR = {psnr:.4f}")

    return {
        "psnr" : psnr_vals,
    }

def shared_encoder_training(coords, data, model, optim, config=None, mlogger=None, name=None):
    """Jointly train a multi-decoder model, one decoder output per target image.

    The loss is the sum over images of each image's MSE. Decoder ``i`` output is
    compared to ``data['gt'][i]``; the stacked outputs and targets must be
    broadcast-compatible.

    Args:
        coords: model input (e.g. the coordinate grid).
        data: dict whose 'gt' entry is a list of target tensors.
        model: module returning a list of tensors (e.g. SharedINR).
        optim: optimizer over the model's parameters.
        config: dict providing 'epochs' (number of iterations).
        mlogger: unused; kept for interface compatibility.
        name: run name; required.

    Returns:
        {"psnr": [float, ...]}: per-iteration mean PSNR over the images,
        -10*log10(per-image MSE), computed before the optimizer step. This
        assumes targets are scaled to [0, 1]. Previously the list was always
        empty.
    """
    assert name is not None, "`name` must be provided as metric logger needs it"
    config = {} if config is None else config
    # Targets do not change during training: stack them once, not every iteration.
    stacked_gt = torch.stack(data['gt'], dim=0)

    tbar = tqdm(range(config['epochs']))
    psnr_vals = []

    for epoch in tbar:
        outputs = model(coords)
        stacked_outputs = torch.stack(outputs, dim=0)
        per_image_mse = ((stacked_outputs - stacked_gt)**2).mean(dim=[1, 2, 3])
        loss = per_image_mse.sum()

        optim.zero_grad()
        loss.backward()
        optim.step()

        psnr_vals.append(float((-10 * torch.log10(per_image_mse.detach())).mean()))

        tbar.set_description(f"Iter {epoch}/{config['epochs']} Loss = {loss.item():6f}")

    return {
        "psnr" : psnr_vals,
    }

"""## Train a Vanilla Siren on some "training" image"""

inr_siren_vanilla1_random_train = SharedINR(in_features=config['in_channels'],
                                hidden_features=config['hidden_features'], hidden_layers=config['num_layers'],
                                shared_encoder_layers = config['shared_encoder_layers'],
                                num_decoders=config['num_decoders'],
                                out_features=config['out_channels']).to(device)
optim_siren_vanilla1_random_train = torch.optim.Adam(lr=config['learning_rate'], params=inr_siren_vanilla1_random_train.parameters())

ret_inr_siren_vanilla1_random_train = fit_inr(coords=coords, data=data_dict_1,
                                                model=inr_siren_vanilla1_random_train,
                                              optim=optim_siren_vanilla1_random_train,
                                              config=config, mlogger=None,name="random_vanilla_train")

"""## Fit Unseen image to randomly initialized Siren"""

inr_siren_vanilla1 = SharedINR(in_features=config['in_channels'],
                                hidden_features=config['hidden_features'], hidden_layers=config['num_layers'],
                                shared_encoder_layers = config['shared_encoder_layers'],
                                num_decoders=config['num_decoders'],
                                out_features=config['out_channels']).to(device)
optim_siren_vanilla1 = torch.optim.Adam(lr=config['learning_rate'], params=inr_siren_vanilla1.parameters())

ret_inr_siren_vanilla1 = fit_inr(coords=coords, data=data_dict_test,
                                                model=inr_siren_vanilla1,
                                              optim=optim_siren_vanilla1,
                                              config=config, mlogger=None,name="inr_siren_vanilla")

"""## Fit Unseen image to a fine-tuned Siren. (Finetuning using the random train image weights)"""

inr_siren_vanilla1_finetuned = SharedINR(in_features=config['in_channels'],
                                hidden_features=config['hidden_features'], hidden_layers=config['num_layers'],
                                shared_encoder_layers = config['shared_encoder_layers'],
                                num_decoders=config['num_decoders'],
                                out_features=config['out_channels']).to(device)
inr_siren_vanilla1_finetuned.load_state_dict(deepcopy(inr_siren_vanilla1_random_train.state_dict()))
optim_siren_vanilla1_finetuned = torch.optim.Adam(lr=config['learning_rate'], params=inr_siren_vanilla1_finetuned.parameters())

ret_inr_siren_finetuned = fit_inr(coords=coords, data=data_dict_test,
                                                model=inr_siren_vanilla1_finetuned,
                                              optim=optim_siren_vanilla1_finetuned,
                                              config=config, mlogger=None,name="inr_siren_finetuned")

"""## Strainer 1-decoder : Only use encoder layers as initialization instead of full model"""

inr_strainer_1decoder = SharedINR(in_features=config['in_channels'],
                                hidden_features=config['hidden_features'], hidden_layers=config['num_layers'],
                                shared_encoder_layers = config['shared_encoder_layers'],
                                num_decoders=config['num_decoders'],
                                out_features=config['out_channels']).to(device)
inr_strainer_1decoder.load_encoder_weights_from(inr_siren_vanilla1_random_train)
optim_siren_strainer1decoder = torch.optim.Adam(lr=config['learning_rate'], params=inr_strainer_1decoder.parameters())

ret_strainer1decoder = fit_inr(coords=coords, data=data_dict_test,
                                                model=inr_strainer_1decoder,
                                              optim=optim_siren_strainer1decoder,
                                              config=config, mlogger=None,name="strainer_encoder_only_1decoder")

"""## [Our method ] : Shared Encoder Training with Strainer"""

inr_strainer_10decoders_train = SharedINR(in_features=config['in_channels'],
                                hidden_features=config['hidden_features'], hidden_layers=config['num_layers'],
                                shared_encoder_layers = config['shared_encoder_layers'],
                                num_decoders=10,
                                out_features=config['out_channels']).to(device)
optim_siren_strainer10decoder_train = torch.optim.Adam(lr=config['learning_rate'], params=inr_strainer_10decoders_train.parameters())

config_train = deepcopy(config)
config_train['epochs'] = config['epochs_train_strainer']
ret_strainer10decoder_train = shared_encoder_training(coords=coords, data=data_dict_train_strainer,
                                                model=inr_strainer_10decoders_train,
                                              optim=optim_siren_strainer10decoder_train,
                                              config=config_train, mlogger=None,name="strainer_encoder_only_10decoder")

"""## [Our solution]: Fit Unseen Image with Strainer. Shared encoder trained on multiple similar images"""

inr_strainer_test = SharedINR(in_features=config['in_channels'],
                                hidden_features=config['hidden_features'], hidden_layers=config['num_layers'],
                                shared_encoder_layers = config['shared_encoder_layers'],
                                num_decoders=config['num_decoders'],
                                out_features=config['out_channels']).to(device)
inr_strainer_test.load_encoder_weights_from(inr_strainer_10decoders_train)
optim_siren_strainer_test = torch.optim.Adam(lr=config['learning_rate'], params=inr_strainer_test.parameters())

ret_strainer_test = fit_inr(coords=coords, data=data_dict_test,
                                                model=inr_strainer_test,
                                              optim=optim_siren_strainer_test,
                                              config=config, mlogger=None,name="strainer_test")

# PSNR-per-iteration curves for the four ways of fitting the held-out test image.
results = {
    "Siren": ret_inr_siren_vanilla1,
    "Siren-finetuned": ret_inr_siren_finetuned,
    "Strainer(encoder learned from 1 image)": ret_strainer1decoder,
    "Strainer (proposed, learned from 10 images)": ret_strainer_test,
}

plt.figure()
for key, ret in results.items():
    plt.plot([float(x) for x in ret['psnr']], label=key)

plt.xlabel("Iterations")
plt.ylabel("Peak signal to noise ratio (PSNR)")
plt.legend(loc="lower right")
# Lay out once, after the axis labels exist, so they are taken into account.
plt.tight_layout()
plt.show()

def calc_model_params(model):
    """Return the total number of scalar parameters in ``model``.

    Parameters shared between submodules are counted once.
    """
    return sum(param.numel() for param in model.parameters())

# Parameter count of each model fitted to the test image.
names = ['Siren', 'Siren Finetuned', 'Strainer(encoder-1decoder)', 'Strainer-10(proposed) test']
models = [inr_siren_vanilla1, inr_siren_vanilla1_finetuned, inr_strainer_1decoder, inr_strainer_test]

print("==" * 50)
for label, net in zip(names, models):
    print(f"{label} : {calc_model_params(net)} parameters")
print("==" * 50)

